# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretraining on GPUs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from npu_bridge.npu_init import *
import os
import sys
import math
import json
import time
import numpy as np
from absl import flags
import absl.logging as _logging
import tensorflow as tf
import data_utils
import model_utils
from gpu_utils import assign_to_gpu, average_grads_and_vars
import function_builder
flags.DEFINE_integer('num_hosts', default=1, help='Number of hosts')
flags.DEFINE_integer('num_core_per_host', default=8, help='Number of cores per host')
flags.DEFINE_bool('use_tpu', default=False, help='Whether to use TPUs for training.')
flags.DEFINE_integer('num_passes', default=1, help='Number of passes over the training data.')
flags.DEFINE_string('record_info_dir', default=None, help='Path to local directory containing `record_info-lm.json`.')
flags.DEFINE_string('model_dir', default=None, help='Estimator model_dir.')
flags.DEFINE_string('init_checkpoint', default=None, help='checkpoint path for initializing the model.')
flags.DEFINE_float('learning_rate', default=0.0001, help='Maximum learning rate.')
flags.DEFINE_float('clip', default=1.0, help='Gradient clipping value.')
flags.DEFINE_float('min_lr_ratio', default=0.001, help='Minimum learning rate as a ratio of the maximum learning rate.')
flags.DEFINE_integer('warmup_steps', default=0, help='Number of steps for linear lr warmup.')
flags.DEFINE_float('adam_epsilon', default=1e-08, help='Adam epsilon.')
flags.DEFINE_string('decay_method', default='poly', help='Learning rate decay method: `poly` or `cos`.')
flags.DEFINE_float('weight_decay', default=0.0, help='Weight decay rate.')
flags.DEFINE_integer('train_batch_size', default=16, help='Size of train batch.')
flags.DEFINE_integer('train_steps', default=100000, help='Total number of training steps.')
flags.DEFINE_integer('iterations', default=1000, help='Number of iterations per repeat loop.')
flags.DEFINE_integer('save_steps', default=None, help='Save a checkpoint every this many steps. Checkpointing is disabled if None.')
flags.DEFINE_integer('seq_len', default=0, help='Sequence length for pretraining.')
flags.DEFINE_integer('reuse_len', default=0, help='Number of tokens to reuse in the next batch. Can be set to half of seq_len.')
flags.DEFINE_bool('bi_data', default=True, help='Use bidirectional data streams, i.e., forward & backward.')
flags.DEFINE_integer('mask_alpha', default=6, help='How many tokens to form a group.')
flags.DEFINE_integer('mask_beta', default=1, help='How many tokens to mask within each group.')
flags.DEFINE_integer('num_predict', default=None, help='Number of tokens to predict in partial prediction.')
flags.DEFINE_integer('perm_size', default=None, help='Window size of the permutation.')
flags.DEFINE_bool('uncased', False, help='Use uncased inputs or not.')
flags.DEFINE_integer('n_token', 32000, help='Vocab size.')
flags.DEFINE_integer('mem_len', default=0, help='Number of steps to cache.')
flags.DEFINE_bool('same_length', default=False, help='Whether to use the same attention length for each token.')
flags.DEFINE_integer('clamp_len', default=-1, help='Clamp length.')
flags.DEFINE_integer('n_layer', default=6, help='Number of layers.')
flags.DEFINE_integer('d_model', default=32, help='Dimension of the model.')
flags.DEFINE_integer('d_embed', default=32, help='Dimension of the embeddings.')
flags.DEFINE_integer('n_head', default=4, help='Number of attention heads.')
flags.DEFINE_integer('d_head', default=8, help='Dimension of each attention head.')
flags.DEFINE_integer('d_inner', default=32, help='Dimension of inner hidden size in positionwise feed-forward.')
flags.DEFINE_float('dropout', default=0.0, help='Dropout rate.')
flags.DEFINE_float('dropatt', default=0.0, help='Attention dropout rate.')
flags.DEFINE_bool('untie_r', default=False, help='Untie r_w_bias and r_r_bias.')
flags.DEFINE_string('summary_type', default='last', help='Method used to summarize a sequence into a compact vector.')
flags.DEFINE_string('ff_activation', default='relu', help='Activation type used in position-wise feed-forward.')
flags.DEFINE_bool('use_bfloat16', False, help='Whether to use bfloat16.')
flags.DEFINE_enum('init', default='normal', enum_values=['normal', 'uniform'], help='Initialization method.')
flags.DEFINE_float('init_std', default=0.02, help='Initialization std when init is normal.')
flags.DEFINE_float('init_range', default=0.1, help='Initialization std when init is uniform.')
FLAGS = flags.FLAGS
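
# Example invocation (script name, paths, and hyperparameters below are
# placeholders; adjust them for your setup):
#
#   python train_gpu.py \
#       --record_info_dir=/path/to/tfrecords \
#       --model_dir=/path/to/model_dir \
#       --num_core_per_host=1 \
#       --seq_len=128 --reuse_len=64 --perm_size=64 \
#       --train_batch_size=16 --save_steps=10000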

def get_model_fn():
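    """Returns a model_fn that builds the loss and per-variable gradients."""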

    def model_fn(features, labels, mems, is_training):
        (total_loss, new_mems, monitor_dict) = function_builder.get_loss(FLAGS, features, labels, mems, is_training)
        num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
        tf.logging.info('#params: {}'.format(num_params))
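
        # Only the training graph is built by this script.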
        assert is_training
        all_vars = tf.trainable_variables()
        grads = tf.gradients(total_loss, all_vars)
        grads_and_vars = list(zip(grads, all_vars))
        return (total_loss, new_mems, grads_and_vars)
    return model_fn

def single_core_graph(is_training, features, mems):
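    """Builds the model graph (loss, new memory, gradients) for one device."""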
    model_fn = get_model_fn()
    model_ret = model_fn(features=features, labels=None, mems=mems, is_training=is_training)
    return model_ret

def create_mems_tf(bsz_per_core):
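    """Creates float32 placeholders for the Transformer-XL style memory of each layer."""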
    mems = [tf.placeholder(dtype=tf.float32,
                           shape=[FLAGS.mem_len, bsz_per_core, FLAGS.d_model])
            for _ in range(FLAGS.n_layer)]
    return mems

def initialize_mems_np(bsz_per_core):
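    """Creates zero-valued numpy arrays used to seed the memory at step 0."""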
    mems_np = [np.zeros(shape=[FLAGS.mem_len, bsz_per_core, FLAGS.d_model],
                        dtype=np.float32)
               for _ in range(FLAGS.n_layer)]
    return mems_np

def train(ps_device):
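    """Builds the multi-tower training graph and runs the training loop."""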
    train_input_fn, record_info_dict = data_utils.get_input_fn(
        tfrecord_dir=FLAGS.record_info_dir,
        split='train',
        bsz_per_host=FLAGS.train_batch_size,
        seq_len=FLAGS.seq_len,
        reuse_len=FLAGS.reuse_len,
        bi_data=FLAGS.bi_data,
        num_hosts=1,
        num_core_per_host=1,
        perm_size=FLAGS.perm_size,
        mask_alpha=FLAGS.mask_alpha,
        mask_beta=FLAGS.mask_beta,
        uncased=FLAGS.uncased,
        num_passes=FLAGS.num_passes,
        use_bfloat16=FLAGS.use_bfloat16,
        num_predict=FLAGS.num_predict)
    tf.logging.info('num of batches {}'.format(record_info_dict['num_batch']))
    bsz_per_core = FLAGS.train_batch_size // FLAGS.num_core_per_host
    params = {'batch_size': FLAGS.train_batch_size}
    train_set = train_input_fn(params)
    example = train_set.make_one_shot_iterator().get_next()
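
    # Split the global batch so that each core (tower) gets its own shard.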
    if FLAGS.num_core_per_host > 1:
        examples = [{} for _ in range(FLAGS.num_core_per_host)]
        for key in example.keys():
            vals = tf.split(example[key], FLAGS.num_core_per_host, 0)
            for device_id in range(FLAGS.num_core_per_host):
                examples[device_id][key] = vals[device_id]
    else:
        examples = [example]
    tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

    # Build one model replica ('tower') per core; variables are shared after
    # the first tower.
    for i in range(FLAGS.num_core_per_host):
        reuse = True if i > 0 else None
        with tf.device('/cpu:0'), tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
            mems_i = {}
            if FLAGS.mem_len:
                mems_i['mems'] = create_mems_tf(bsz_per_core)

            loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
                is_training=True, features=examples[i], mems=mems_i)
            tower_mems.append(mems_i)
            tower_losses.append(loss_i)
            tower_new_mems.append(new_mems_i)
            tower_grads_and_vars.append(grads_and_vars_i)
    # Aggregate losses and gradients across towers.
    if len(tower_losses) > 1:
        loss = tf.add_n(tower_losses) / len(tower_losses)
        grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
    else:
        loss = tower_losses[0]
        grads_and_vars = tower_grads_and_vars[0]
    (train_op, learning_rate, gnorm) = model_utils.get_train_op(FLAGS, None, grads_and_vars=grads_and_vars)
    global_step = tf.train.get_global_step()
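
    # Allocate numpy-side memory buffers; they are fed in and read back on
    # every training step.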
    tower_mems_np = []
    for i in range(FLAGS.num_core_per_host):
        mems_i_np = {}
        for key in tower_mems[i].keys():
            mems_i_np[key] = initialize_mems_np(bsz_per_core)
        tower_mems_np.append(mems_i_np)
    saver = tf.train.Saver()
    gpu_options = tf.GPUOptions(allow_growth=True)
    model_utils.init_from_checkpoint(FLAGS, global_vars=True)
    with tf.Session(config=npu_config_proto(config_proto=tf.ConfigProto(
            allow_soft_placement=True, gpu_options=gpu_options))) as sess:
        sess.run(tf.global_variables_initializer())
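
        # The logging below indexes from the end of `fetches`: fetched[-3] is
        # gnorm and fetched[-2] is the learning rate.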
        fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]
        total_loss, prev_step = 0.0, -1
        while True:
            # Feed the memory produced by the previous step back into the graph.
            feed_dict = {}
            for i in range(FLAGS.num_core_per_host):
                for key in tower_mems_np[i].keys():
                    for m, m_np in zip(tower_mems[i][key], tower_mems_np[i][key]):
                        feed_dict[m] = m_np

            fetched = sess.run(fetches, feed_dict=feed_dict)
            loss_np, tower_mems_np, curr_step = fetched[:3]
            total_loss += loss_np
            if curr_step > 0 and curr_step % FLAGS.iterations == 0:
                curr_loss = total_loss / (curr_step - prev_step)
                tf.logging.info(
                    '[{}] | gnorm {:.2f} lr {:8.6f} | loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}'.format(
                        curr_step, fetched[-3], fetched[-2], curr_loss,
                        math.exp(curr_loss), curr_loss / math.log(2)))
                total_loss, prev_step = 0.0, curr_step

            # Guard against the default save_steps=None, which would crash the
            # modulo below.
            if FLAGS.save_steps is not None and curr_step > 0 and curr_step % FLAGS.save_steps == 0:
                save_path = os.path.join(FLAGS.model_dir, 'model.ckpt')
                saver.save(sess, save_path)
                tf.logging.info('Model saved in path: {}'.format(save_path))

            if curr_step >= FLAGS.train_steps:
                break

def main(unused_argv):
    del unused_argv
    tf.logging.set_verbosity(tf.logging.INFO)

    # Get corpus info.
    FLAGS.n_token = data_utils.VOCAB_SIZE
    tf.logging.info('n_token {}'.format(FLAGS.n_token))

    if not tf.gfile.Exists(FLAGS.model_dir):
        tf.gfile.MakeDirs(FLAGS.model_dir)

    # ps_device is currently unused by train(); towers are placed on '/cpu:0'.
    train('/gpu:0')


if __name__ == '__main__':
    tf.app.run()
